### Compare our predictions of war outcomes to alternative models with different
### capability aggregation rules for coalitions

## Record the start time, attach packages, and pull in the project backend.
## NOTE(review): structwar_setup(), extract_params() and contest_vals(), used
## further down, are presumably defined by backend_main.r --- confirm.
date()
library("tidyverse")
library("foreach")
source("backend_main.r")
## Record R and package versions alongside the results
sessionInfo()


## Saved results from earlier steps. The rest of the script reads `fit_base`
## (a list with one element per imputation) and `boot_base`, which are
## expected to be restored by these two files.
load("results/fit_base.rda")
load("results/boot_base.rda")


## Split the wars into coalition and bilateral cases. The war rows are taken
## from the first imputation's dispute data; a war counts as a coalition war
## when more than two states take part across both sides.
wars_data <- filter(fit_base[[1]]$data_dispute, war == 1)
idx_coa <- (wars_data$n_states_a + wars_data$n_states_b) > 2
idx_non <- !idx_coa
id_non <- wars_data$id[idx_non]

## Base model: for every imputation, the probability the fitted structural
## model assigns to the outcome that was actually observed in each war
lik_structwar <- foreach (imp = fit_base) %do% {
  ## Model inputs are rebuilt from the pieces stored alongside the fit
  setup <- with(imp, structwar_setup(f_dispute = f_dispute,
                                     f_participant = f_participant,
                                     data_dispute = data_dispute,
                                     data_participant = data_participant,
                                     xlev_dispute = xlev_dispute,
                                     xlev_participant = xlev_participant,
                                     n_halton = n_halton))
  parms <- extract_params(est = coef(imp$fit),
                          setup = setup,
                          for_counterfactuals = TRUE)
  prob_win_a <- contest_vals(dispute_level_id = setup$dispute_level_id,
                             state_level_id = setup$state_level_id,
                             side_a = setup$side_a,
                             ratio = parms$ratio,
                             xmax = .Machine$double.xmax)$prob_win_a

  ## Probability of the observed winner, kept only for disputes that were wars
  disputes <- imp$data_dispute
  if_else(disputes$win_a == 1, prob_win_a, 1.0 - prob_win_a)[disputes$war == 1]
}

## One row per imputation; average over imputations, then take logs
lik_structwar <- lik_structwar %>%
  do.call("rbind", .) %>%
  colMeans()
ll_structwar <- log(lik_structwar)
## Point estimates: each structural coefficient averaged over the imputations
dat_structwar <- fit_base %>%
  lapply(function(imp) coef(imp$fit)) %>%
  do.call("rbind", .) %>%
  as_tibble() %>%
  gather(key = "term", value = "value") %>%
  group_by(term) %>%
  summarise(est = mean(value))

## Standard errors: standard deviation of each coefficient across boot_base
dat_boot_se <- do.call("rbind", boot_base) %>%
  as_tibble() %>%
  gather(key = "term", value = "value") %>%
  group_by(term) %>%
  summarise(se = sd(value))

## Keep the beta/gamma terms, rename them to match the logit difference
## columns, and stack estimate and SE rows in table order. Gamma estimates are
## sign-flipped (NOTE(review): presumably so they read in the same direction
## as the logit differences --- confirm against the model definition).
dat_structwar <- dat_structwar %>%
  left_join(dat_boot_se, by = "term") %>%
  filter(str_detect(term, "^(beta|gamma)")) %>%
  mutate(est = if_else(str_detect(term, "^gamma"), -est, est),
         term = term %>%
           str_replace("^(beta|gamma):", "") %>%
           str_replace("log1p\\(", "") %>%
           str_replace("log\\(", "") %>%
           str_replace("\\)", "") %>%
           paste0("_diff")) %>%
  gather(key = "name", value = "structural", c(est, se)) %>%
  mutate(term = factor(term,
                       levels = c("gdp_pwt_diff",
                                  "irst_diff",
                                  "pec_diff",
                                  "tpop_diff",
                                  "upop_diff",
                                  "distance_diff",
                                  "nuclear_diff",
                                  "pct_imports_diff",
                                  "polity2_diff"))) %>%
  arrange(term, name) %>%
  mutate(term = paste0(term, "_", name)) %>%
  select(-name) %>%
  ## Summary rows: number of wars and mean log-likelihood by war type
  add_row(term = c("n_war", "ll_bilateral", "ll_coalition"),
          structural = c(length(ll_structwar),
                         mean(ll_structwar[idx_non]),
                         mean(ll_structwar[idx_coa])))

## Build the side A minus side B covariate differences used by the logit
## models, transformed to best match the contest model in bilateral cases
##
## Arguments:
##   dat  Data frame with columns id, sidea (1 marks side A, anything else is
##        treated as side B) and the covariates used below. Each id/side pair
##        must appear on a single row, otherwise spread() cannot reshape it.
##
## Returns one row per id with the column id and one <covariate>_diff column
## per covariate.
make_logit_cols <- function(dat) {
  ## Long format: one row per id, side and covariate
  by_side <- dat %>%
    mutate(sidea = if_else(sidea == 1, "a", "b")) %>%
    gather(key = "name", value = "value", c(-id, -sidea))

  ## Wide format: one column per covariate/side pair, e.g. gdp_pwt_a
  by_dispute <- by_side %>%
    mutate(variable = paste0(name, "_", sidea)) %>%
    select(-sidea, -name) %>%
    spread(key = "variable", value = "value")

  ## Differences: log for GDP, log1p for the other size measures, raw values
  ## for nuclear weapons and polity
  transmute(by_dispute,
            id = id,
            gdp_pwt_diff = log(gdp_pwt_a) - log(gdp_pwt_b),
            irst_diff = log1p(irst_a) - log1p(irst_b),
            pec_diff = log1p(pec_a) - log1p(pec_b),
            tpop_diff = log1p(tpop_a) - log1p(tpop_b),
            upop_diff = log1p(upop_a) - log1p(upop_b),
            distance_diff = log1p(distance_a) - log1p(distance_b),
            nuclear_diff = nuclear_a - nuclear_b,
            pct_imports_diff = log1p(pct_imports_a) - log1p(pct_imports_b),
            polity2_diff = polity2_a - polity2_b)
}

## Calculate MI coefficients, SEs, and log-likelihoods from a list of logit
## models
##
## Arguments:
##   mods      Non-empty list of fitted glm objects, one per imputed dataset.
##             The outcome vector is read from the first model only, so all
##             models are assumed to share the same observations in the same
##             order.
##   val_name  Optional name for the value column of the result; when NULL the
##             column keeps the name "value".
##
## Returns a long tibble with a `term` column and one value column holding
##   <coef>_est     coefficient averaged over the models
##   <coef>_se      square root of the pooled variance: mean within-model
##                  variance plus (1 + 1/m) times the across-model variance
##   n_war          number of observations
##   ll_bilateral,  mean log-likelihood by war type, added only when the
##   ll_coalition   number of observations lines up with the index vectors
##
## Reads the script-level vectors idx_coa and idx_non to split the
## log-likelihoods by war type.
mi_from_list <- function(mods, val_name = NULL) {
  m <- length(mods)
  if (m == 0)
    stop("mi_from_list() needs at least one fitted model", call. = FALSE)

  ## Average coefficients
  l_coef <- do.call("rbind", map(mods, coef))
  cf <- colMeans(l_coef)

  ## Standard errors
  l_var <- do.call("rbind", map(mods, ~ diag(vcov(.x))))
  v_within <- colMeans(l_var)
  ## var() of a single value is NA, which would turn every SE into NA when
  ## only one model is supplied; with one model there is no across-model
  ## variance to add, so use the within-model variance alone
  v_across <- if (m > 1) apply(l_coef, 2, var) else 0
  v <- v_within + (m + 1) * v_across / m
  se <- sqrt(v)

  ## Construct data frame
  df_out <- tibble(term = names(cf), est = cf, se = se) %>%
    gather(key = "name", value = "value", c(est, se)) %>%
    mutate(term = paste0(term, "_", name)) %>%
    select(-name)

  ## Individual log-likelihoods: probability of the observed outcome, averaged
  ## over the models before taking logs
  y <- unname(mods[[1]]$y)
  l_pred <- map(mods, ~ unname(predict(.x, type = "response")))
  l_lik <- map(l_pred, ~ if_else(y == 1, .x, 1.0 - .x))
  l_lik <- do.call("rbind", l_lik)
  lik <- colMeans(l_lik)
  ll <- log(lik)
  df_out <- df_out %>%
    add_row(term = "n_war", value = length(ll))
  if (length(ll) == length(idx_coa)) {
    ## Models fit to every war: split by war type
    df_out <- df_out %>%
      add_row(term = "ll_bilateral", value = mean(ll[idx_non])) %>%
      add_row(term = "ll_coalition", value = mean(ll[idx_coa]))
  } else if (length(ll) == sum(idx_non)) {
    ## Models fit to bilateral wars only: no coalition figure exists
    df_out <- df_out %>%
      add_row(term = "ll_bilateral", value = mean(ll)) %>%
      add_row(term = "ll_coalition", value = NA)
  } else {
    ## Previously this case dropped the log-likelihood rows without any
    ## signal (e.g. when glm() discarded rows with missing values)
    warning("mi_from_list(): ", length(ll), " observations match neither ",
            "all wars (", length(idx_coa), ") nor the bilateral wars (",
            sum(idx_non), "); log-likelihood rows omitted",
            call. = FALSE)
  }

  if (!is.null(val_name))
    names(df_out)[2] <- val_name

  df_out
}

## Logits for the alternative where each side's capabilities are simply added
## up: size measures and nuclear status are summed over a side's members,
## while distance, import share and polity are averaged
logit_sum <- foreach (imp = fit_base) %do% {
  side_totals <- imp$data_participant %>%
    select(id, sidea, gdp_pwt, irst, pec, tpop, upop,
           distance, nuclear, pct_imports, polity2) %>%
    group_by(id, sidea) %>%
    summarise(gdp_pwt = sum(gdp_pwt),
              irst = sum(irst),
              pec = sum(pec),
              tpop = sum(tpop),
              upop = sum(upop),
              distance = mean(distance),
              nuclear = sum(nuclear),
              pct_imports = mean(pct_imports),
              polity2 = mean(polity2)) %>%
    ungroup()

  ## Outcome for each war joined to the side A minus side B differences
  is_war <- imp$data_dispute$war == 1
  d_war <- imp$data_dispute[is_war, ] %>%
    select(id, win_a) %>%
    left_join(make_logit_cols(side_totals), by = "id")

  ## No intercept, mirroring the other logit specifications in this script
  glm(win_a ~ gdp_pwt_diff + irst_diff + pec_diff + tpop_diff + upop_diff +
        distance_diff + nuclear_diff + pct_imports_diff + polity2_diff - 1,
      data = d_war,
      family = binomial(link = "logit"))
}
dat_sum <- mi_from_list(logit_sum, "logit_sum")

## Do a similar thing but for max free riding --- just take the values for the
## coalition member with the highest cinc score on each side of each dispute
logit_max <- foreach (imp = fit_base) %do% {
  d_participant <- imp$data_participant %>%
    select(id, sidea, gdp_pwt, irst, pec, tpop, upop, distance, nuclear, pct_imports, polity2, cinc) %>%
    group_by(id, sidea) %>%
    filter(cinc == max(cinc)) %>%
    ## Members tied for the highest cinc would leave several rows for one
    ## side, and spread() inside make_logit_cols() aborts on the resulting
    ## duplicate keys; keep the first tied row so each side has exactly one
    slice(1) %>%
    ungroup() %>%
    make_logit_cols()
  ## Outcome for each war joined to the side A minus side B differences
  d_dispute <- imp$data_dispute[imp$data_dispute$war == 1, ] %>%
    select(id, win_a) %>%
    left_join(d_participant, by = "id")
  ## Same no-intercept specification as the other logits in this script
  fit_logit <- glm(win_a ~ gdp_pwt_diff + irst_diff + pec_diff + tpop_diff +
                     upop_diff + distance_diff + nuclear_diff + pct_imports_diff +
                     polity2_diff - 1,
                   data = d_dispute,
                   family = binomial(link = "logit"))
  fit_logit
}
dat_max <- mi_from_list(logit_max, "logit_max")

## Logits estimated on the bilateral wars alone (ids in id_non), where each
## side's covariates are used as they are
logit_bi <- foreach (imp = fit_base) %do% {
  bilateral_diffs <- imp$data_participant %>%
    filter(id %in% !! id_non) %>%
    select(id, sidea, gdp_pwt, irst, pec, tpop, upop,
           distance, nuclear, pct_imports, polity2, cinc) %>%
    make_logit_cols()

  ## Outcome for each bilateral war joined to its covariate differences
  d_bilateral <- imp$data_dispute %>%
    filter(id %in% !! id_non) %>%
    select(id, win_a) %>%
    left_join(bilateral_diffs, by = "id")

  glm(win_a ~ gdp_pwt_diff + irst_diff + pec_diff + tpop_diff + upop_diff +
        distance_diff + nuclear_diff + pct_imports_diff + polity2_diff - 1,
      data = d_bilateral,
      family = binomial(link = "logit"))
}
dat_bi <- mi_from_list(logit_bi, "logit_bi")

## Put the models together
##
## Joins the three logit summaries onto the structural results by term, turns
## every number into LaTeX-ready text, and assembles one table line per row in
## the `glue` column. The formatting steps below depend on their order.
dat_full <- dat_structwar %>%
  left_join(dat_bi, by = "term") %>%
  left_join(dat_sum, by = "term") %>%
  left_join(dat_max, by = "term") %>%
  select(term, logit_bi, logit_sum, logit_max, structural) %>%
  ## Three decimal places for everything; missing values become the text "NA"
  mutate_at(vars(-term), ~ sprintf("%.3f", .x)) %>%
  ## The war count is a whole number, so drop its decimal part
  mutate_at(vars(-term), ~ if_else(term == "n_war", str_replace(.x, "\\..*$", ""), .x)) %>%
  ## Typeset a leading minus sign in math mode
  mutate_at(vars(-term), ~ str_replace(.x, "^\\-", "$-$")) %>%
  ## Show missing entries as a dash
  mutate_at(vars(-term), ~ if_else(.x == "NA", "--", .x)) %>%
  ## Standard errors go in parentheses
  mutate_at(vars(-term), ~ if_else(str_detect(term, "_se$"), paste0("(", .x, ")"), .x)) %>%
  ## linebr: SE rows end with extra vertical space after the line break
  ## midrule: a rule is placed after the polity2 SE row, the last covariate
  ## term: blank on SE rows, otherwise the display label for the row
  mutate(linebr = if_else(str_detect(term, "_se$"), "\\\\[0.2em]", "\\\\"),
         midrule = if_else(term == "polity2_diff_se", "\\midrule", ""),
         term = if_else(str_detect(term, "_se$"), "", str_replace(term, "_diff_est", "")),
         term = recode(term,
                       gdp_pwt = "GDP",
                       irst = "Iron and Steel",
                       pec = "Energy Consumption",
                       tpop = "Total Population",
                       upop = "Urban Population",
                       distance = "Distance to Dispute",
                       nuclear = "Nuclear Weapons",
                       pct_imports = "Import Percentage",
                       polity2 = "Democracy",
                       n_war = "No.\\ Wars",
                       ll_bilateral = "LL: Bilateral Wars",
                       ll_coalition = "LL: Coalition Wars"),
         ## One complete LaTeX table line per row
         glue = str_glue("{term} & {logit_bi} & {logit_sum} & {logit_max} & {structural} {linebr} {midrule}"))

## Full table for the appendix: fixed header lines, then one line per row of
## dat_full, then the closing lines
tab_full <- c(
  c("\\begin{tabular}{lcccc}",
    "\\toprule",
    "& Bilateral & Summed       & Max          & Equilibrium \\\\",
    "& Only      & Capabilities & Capabilities & Model \\\\",
    "\\midrule"),
  dat_full$glue,
  c("\\bottomrule",
    "\\end{tabular}")
)

## Create the output directory on first use, then write the table
if (!dir.exists("tables")) {
  dir.create("tables")
}
writeLines(text = tab_full, con = "tables/table_A8.tex")


## Record the finish time
date()
